from glob import glob
from PIL import Image
import torch
from torch.utils.data.dataset import Dataset

import numpy as np
import scipy.misc
import os
# from PIL import Image
from torchvision import transforms
# import torch

class CUB():
    """Eagerly loaded image dataset described by three index files.

    ``root`` must contain ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` plus an ``images`` directory. Each index file
    holds one record per line and only the last space-separated field of a
    line is used; the three files are matched up line by line.

    Args:
        root: Dataset directory.
        is_train: When true, keep the samples whose split flag is non-zero
            and expose them as ``train_img`` / ``train_label``; otherwise
            keep the samples whose flag is zero and expose them as
            ``test_img`` / ``test_label``.
        data_len: Keep only the first ``data_len`` samples of the selected
            split. ``None`` keeps all of them.
        transform: Optional callable applied to the PIL image of a sample.
        target_transform: Optional callable applied to the label of a sample.

    Labels are stored zero-based (the value in the label file minus one).
    All images of the selected split are decoded in ``__init__`` and kept in
    memory as ``H x W x 3`` arrays.
    """

    def __init__(self, root, is_train=True, data_len=None,transform=None, target_transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform
        self.target_transform = target_transform

        img_name_list = self._read_last_field(os.path.join(self.root, 'images.txt'))
        label_list = [
            label - 1
            for label in self._read_last_field(
                os.path.join(self.root, 'image_class_labels.txt'), int)
        ]
        train_test_list = self._read_last_field(
            os.path.join(self.root, 'train_test_split.txt'), int)

        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]

        train_label_list = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
        test_label_list = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]

        # Only the requested split is decoded; the other one is never read.
        if self.is_train:
            self.train_img = self._load_images(train_file_list[:data_len])
            self.train_label = train_label_list
        else:
            self.test_img = self._load_images(test_file_list[:data_len])
            self.test_label = test_label_list

    @staticmethod
    def _read_last_field(path, convert=str):
        """Return ``convert(last field)`` for every non-blank line of ``path``.

        Lines are stripped instead of sliced with ``[:-1]`` so that a final
        line without a trailing newline does not lose its last character.
        The file is closed as soon as it has been read.
        """
        with open(path) as handle:
            stripped = (raw.strip() for raw in handle)
            return [convert(line.split(' ')[-1]) for line in stripped if line]

    def _load_images(self, file_names):
        """Decode ``file_names`` (relative to ``root/images``) into RGB arrays."""
        arrays = []
        for file_name in file_names:
            # scipy.misc.imread no longer exists in current SciPy, so the
            # file is decoded with PIL directly. Converting to RGB here keeps
            # every stored array three-channel, whatever the source mode.
            with Image.open(os.path.join(self.root, 'images', file_name)) as pil_img:
                arrays.append(np.array(pil_img.convert('RGB')))
        return arrays

    def __getitem__(self,index):
        """Return ``(image, label)`` for ``index`` after the optional transforms."""
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        else:
            img, target = self.test_img[index], self.test_label[index]

        # Arrays loaded by this class are already three-channel; the guard
        # still expands a 2-D (single channel) array assigned from outside.
        if len(img.shape) == 2:
            img = np.stack([img]*3,2)
        img = Image.fromarray(img,mode='RGB')
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


# glob returns paths in a filesystem-dependent order; sorting keeps the
# sample order (and any split derived from it) the same on every machine.
filenames = sorted(glob('./data/oxford_pet/images/*.jpg'))
classes = set()  # distinct class names found in the file names
data = []  # loaded images, later paired with their label indices
labels = []  # class name of each image, later replaced by an index tensor
# Load the images and get the classnames from the image path 
def load_image(filename):
    """Open ``filename`` with PIL and return its content as an RGB image.

    The file is opened in a ``with`` block so its handle is released before
    returning; ``convert`` yields a separate image that stays usable after
    the source has been closed.
    """
    with Image.open(filename) as source:
        return source.convert('RGB')

for image in filenames:
    # The class name is the file name up to its last underscore.
    # os.path.basename copes with whichever separator glob produced.
    class_name = os.path.basename(image).rsplit('_', 1)[0]
    classes.add(class_name)
    img = load_image(image)
    data.append(img)
    labels.append(class_name)
# Convert class names to indices. The names are sorted first because the
# iteration order of a set of strings changes between interpreter runs,
# which would otherwise give every run a different label numbering.
class2idx = {cl: idx for idx, cl in enumerate(sorted(classes))}
# Build the int64 tensor directly instead of going through a float tensor.
labels = torch.tensor([class2idx[name] for name in labels], dtype=torch.long)
data = list(zip(data, labels))

# print(class2idx)
# print(labels.size(0))

class PetDataset(Dataset):
    """Serve ``(image, label)`` pairs from an in-memory sequence.

    ``data`` is any indexable collection of two-item records. When
    ``transforms`` is truthy it is called on the first item of a record
    before the pair is returned; the second item is passed through as is.
    """

    def __init__(self, data, transforms=None):
        self.data = data
        self.transforms = transforms
        # Size is captured once at construction time.
        self.len = len(data)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample, target = self.data[index]
        if not self.transforms:
            return sample, target
        return self.transforms(sample), target

class Databasket():
    """Split samples class by class into a training and a validation set.

    Every record of ``data`` is filed under the class index given by
    ``record[1].item()``. For each class the first
    ``int(count * (1 - val_split))`` records go to the training set and the
    remaining ones to the validation set; with ``resplit`` only every second
    remaining record is kept for validation. The results are exposed as the
    lists ``train_data`` / ``val_data`` and as ``PetDataset`` objects
    ``train_ds`` / ``val_ds`` built with the respective transforms.
    """
    def __init__(self, data=data, num_cl=len(classes), val_split=0.2, train_transforms=None, val_transforms=None, resplit = False):
        # One bucket per class index, filled in the order of ``data``.
        buckets = [[] for _ in range(num_cl)]
        for record in data:
            buckets[record[1].item()].append(record)

        self.train_data = []
        self.val_data = []

        # A step of 2 thins the held-out tail to every second record.
        tail_step = 2 if resplit else 1
        for bucket in buckets:
            cut = int(len(bucket)*(1-val_split))
            self.train_data.extend(bucket[:cut])
            self.val_data.extend(bucket[cut::tail_step])

        self.train_ds = PetDataset(self.train_data, transforms=train_transforms)
        self.val_ds = PetDataset(self.val_data, transforms=val_transforms)